
Conversation

@wingertge (Contributor)

Pull Request Template

Checklist

  • Confirmed that cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Migrates to changes in tracel-ai/cubecl#1127 and tracel-ai/cubek#51

Changes

Migrates cubecl kernels and fusion to usize indexing
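
For context, a minimal illustration of why the index type matters (plain Rust, not CubeCL code; the shape below is made up): linear offsets computed with a 32-bit index wrap once a tensor exceeds u32::MAX elements, which is what the move to usize indexing avoids.

fn main() {
    // Illustration only: a tensor with 2^33 elements cannot be addressed
    // with a 32-bit linear index.
    let shape = [8usize, 1024, 1024, 1024];
    let num_elems: usize = shape.iter().product(); // 8_589_934_592 on 64-bit targets
    assert!(num_elems > u32::MAX as usize); // a u32 linear offset would wrap
    println!("linear index space: {num_elems} elements");
}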

Testing

The test suite runs successfully, though 64-bit indexing is not yet enabled.

@laggui (Member) left a comment


Changes in burn LGTM! Pending rev updates once linked PRs are merged.

burn_tensor::bf16,
"../autodiff/mod.rs",
- ["vulkan", "metal"] // ["cuda", "rocm"] TODO
+ ["metal"] // ["cuda", "rocm"] TODO, ["vulkan"] only supports bf16 for matmul

@laggui (Member):

Whoops, I added bf16 for vulkan but clearly the tests don't pass 😅 thanks for fixing.

[Unrelated to this PR]

I vaguely remember adding it when refactoring the tests since it was a supported global type for vulkan in cubecl. Also reflected by B::supports_dtype:

#[test]
fn should_support_dtypes() {
    type B = Wgpu;
    let device = Default::default();
    assert!(B::supports_dtype(&device, DType::F32));
    assert!(B::supports_dtype(&device, DType::I64));
    assert!(B::supports_dtype(&device, DType::I32));
    assert!(B::supports_dtype(&device, DType::U64));
    assert!(B::supports_dtype(&device, DType::U32));
    assert!(B::supports_dtype(
        &device,
        DType::QFloat(CubeTensor::<WgpuRuntime>::default_scheme())
    ));
    // Registered as supported type but we don't actually use it?
    assert!(B::supports_dtype(&device, DType::Bool));
    #[cfg(feature = "vulkan")]
    {
        assert!(B::supports_dtype(&device, DType::F16));
        assert!(B::supports_dtype(&device, DType::BF16));
        assert!(B::supports_dtype(&device, DType::I16));
        assert!(B::supports_dtype(&device, DType::I8));
        assert!(B::supports_dtype(&device, DType::U16));
        assert!(B::supports_dtype(&device, DType::U8));
        assert!(!B::supports_dtype(&device, DType::F64));
        assert!(!B::supports_dtype(&device, DType::Flex32));
    }
}

maybe we should have a better way to represent the actual supported types?

fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
    let client = R::client(device);
    let ty: StorageType = dtype.into();
    client.properties().supports_type(ty.elem_type())
}

That way, tested dtypes can actually reflect the supported dtypes.

@wingertge (Contributor, Author) · Jan 5, 2026:

The supports_type method on the CubeCL side only checks if a type is supported in any way (in this case, it's supported for conversion, as a type for buffers, and for dot product on Intel, along with tensor core instructions). It's kinda tough though because there's no good way to express that in just a single boolean (hence why the TypeUsage enum set exists in CubeCL).

This is how it's registered for Vulkan:

if let Some(bfloat16) = ext_feat.bfloat16 {
    if bfloat16.shader_b_float16_type == TRUE {
        register(
            ElemType::Float(FloatKind::BF16).into(),
            TypeUsage::Conversion | TypeUsage::Buffer,
        );
    }
    if bfloat16.shader_b_float16_dot_product == TRUE {
        register(
            ElemType::Float(FloatKind::BF16).into(),
            TypeUsage::DotProduct.into(),
        );
    }
}

@wingertge (Contributor, Author):

So it's supported for matmul and casting, but none of the other ops.
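
To make the distinction concrete, here is a rough sketch of the idea (hypothetical names and plain Rust, not CubeCL's actual API): support expressed as a set of usages rather than a single boolean, so a "fully supported" query can require more than conversion and storage.

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum TypeUsage {
    Conversion,
    Buffer,
    DotProduct,
    Arithmetic,
}

struct TypeSupport {
    usages: Vec<TypeUsage>,
}

impl TypeSupport {
    fn supports(&self, usage: TypeUsage) -> bool {
        self.usages.contains(&usage)
    }
}

fn main() {
    // BF16 on Vulkan as registered above: conversion + buffer (+ dot product),
    // but no general arithmetic, so elementwise kernels would reject it even
    // though a plain "is this type known at all?" check returns true.
    let bf16 = TypeSupport {
        usages: vec![TypeUsage::Conversion, TypeUsage::Buffer, TypeUsage::DotProduct],
    };
    assert!(bf16.supports(TypeUsage::Buffer));
    assert!(!bf16.supports(TypeUsage::Arithmetic));
}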

@laggui (Member):

The supports_type method on the CubeCL side only checks if a type is supported in any way

Yeah and for the first draft I simply mirrored that for the backends, but I think it should be refined.

It's kinda tough though because there's no good way to express that in just a single boolean (hence why the TypeUsage enum set exists in CubeCL).

That's a good point. It's still useful to query backend-supported types for burn, so maybe we should also define an enum similar to TypeUsage? (Without atomics, and perhaps consolidating conversion / buffer into a "storage" variant or similar.)
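
A rough sketch of what that could look like (names are hypothetical, not an existing burn API):

// Sketch only: a burn-level usage enum, coarser than CubeCL's TypeUsage
// (no atomics; Conversion + Buffer consolidated into a single Storage variant).
use burn_tensor::DType;

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum DTypeUsage {
    /// The dtype can be stored in tensors and converted to/from other dtypes.
    Storage,
    /// The dtype works for general elementwise / arithmetic kernels.
    Arithmetic,
    /// The dtype is accepted by matmul / accelerated instructions.
    Matmul,
}

pub trait SupportsDType {
    type Device;

    /// Per-usage query instead of the single-boolean supports_dtype.
    fn supports_dtype_for(device: &Self::Device, dtype: DType, usage: DTypeUsage) -> bool;
}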

@wingertge (Contributor, Author):

Regardless, that's something for a separate issue/PR I think.

@laggui (Member) · Jan 8, 2026:

Oh yes entirely, that's why I prefaced my initial comment with [Unrelated to this PR] 😄

Just wanted to get your thoughts since this was related to the bf16 change.

@wingertge requested a review from laggui · January 8, 2026 16:46
@codecov (bot) commented Jan 8, 2026

Codecov Report

❌ Patch coverage is 0% with 831 lines in your changes missing coverage. Please review.
✅ Project coverage is 68.87%. Comparing base (d682723) to head (e0d9252).
⚠️ Report is 4 commits behind head on main.

Files with missing lines                              | Patch % | Lines
crates/burn-cubecl-fusion/src/engine/codegen/io.rs    | 0.00%   | 89 Missing ⚠️
.../cube/connected_components/hardware_accelerated.rs | 0.00%   | 77 Missing ⚠️
...es/burn-cubecl-fusion/src/engine/codegen/kernel.rs | 0.00%   | 54 Missing ⚠️
crates/burn-cubecl-fusion/src/optim/matmul/args.rs    | 0.00%   | 39 Missing ⚠️
...ecl-fusion/src/engine/launch/vectorization/base.rs | 0.00%   | 35 Missing ⚠️
...c/backends/cube/connected_components/prefix_sum.rs | 0.00%   | 34 Missing ⚠️
crates/burn-cubecl/src/kernel/conv/direct.rs          | 0.00%   | 28 Missing ⚠️
...rates/burn-cubecl/src/kernel/index/slice_assign.rs | 0.00%   | 27 Missing ⚠️
...ates/burn-cubecl-fusion/src/engine/codegen/view.rs | 0.00%   | 26 Missing ⚠️
...-cubecl/src/kernel/conv/deform_conv_transpose2d.rs | 0.00%   | 26 Missing ⚠️
... and 53 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #4273      +/-   ##
==========================================
+ Coverage   68.85%   68.87%   +0.02%     
==========================================
  Files        1405     1405              
  Lines      167686   167607      -79     
==========================================
- Hits       115456   115442      -14     
+ Misses      52230    52165      -65     

☔ View full report in Codecov by Sentry.

@laggui merged commit 7b32614 into tracel-ai:main · Jan 8, 2026
10 checks passed
@wingertge deleted the feat/usize branch · January 8, 2026 21:16